<!DOCTYPE html>
<html lang="en-US">
<head>
    <meta charset="utf-8"/>
    <meta content="IE=edge" http-equiv="X-UA-Compatible"/>
    <meta content="width=device-width, initial-scale=1" name="viewport"/>
    <!-- Every document needs a title: it labels the browser tab, bookmarks and search results. -->
    <title>BackPACK - Get more out of your backward pass</title>
    <!-- link is a void element and therefore has no end tag. -->
    <link href="/assets/css/style.css" rel="stylesheet"/>
</head>
<body>
<div class="wrapper">
    <header>
        <!-- Logo and project name, placed side by side and vertically centred by the flex container. -->
        <div style="display: flex; align-items: center; height: 100%; margin-bottom: 5px;">
            <!-- Decorative image: the heading right next to it already names the project, hence the empty alt. -->
            <img src="/assets/img/backpack_logo_torch.svg"
                 alt=""
                 style="width: 48px; margin-right:10px">
            <h1 style="margin-top:auto; margin-bottom:auto; display:block">
                <a href="https://f-dangel.github.io/backpack/">
                    BackPACK
                </a>
            </h1>
        </div>
        <p>
            Get more out of your backward pass
        </p>
        <!--
            Project links, one single-item list per link.
            The lists are direct children of the header on purpose: a <p> cannot
            contain a <ul> (the parser closes the paragraph as soon as the list
            starts), so wrapping them in paragraphs only produces empty <p> elements.
        -->
        <ul>
            <li>
                <a href="examples.html">
                    BackPACK on a small example
                </a>
            </li>
        </ul>
        <ul>
            <li>
                <a href="https://docs.backpack.pt">
                    Documentation
                </a>
            </li>
        </ul>
        <ul>
            <li>
                <a href="https://github.com/f-dangel/backpack">
                    GitHub repo
                </a>
            </li>
        </ul>
    </header>
    <section>
        <p>BackPACK is a library built on top of <a href="https://pytorch.org/">PyTorch</a>
to make it easy to extract more information from a backward pass.
Some of the things you can compute:</p>
<!--
    Tab bar for the code examples below.
    Each entry pairs an <li id="...Item"> (it carries the "active" class of the
    selected tab) with an <a id="...Button">; the inline script further down
    looks both up by id and suppresses the default link navigation on click.
    Every link points at the panel it reveals, so a click that arrives before
    the script has attached its handlers does not reload the page.
    NOTE(review): <center> is obsolete in HTML5. It is kept because replacing it
    needs an equivalent rule in style.css, which cannot be checked from here.
-->
<center>
    <ul class="downloads">
        <li id="gradientItem" class="active">
            <a href="#gradientCode" id="gradientButton">
                the gradient with
                <strong>
                    PyTorch
                </strong>
            </a>
        </li>
        <li id="varianceItem">
            <a href="#varianceCode" id="varianceButton">
                an estimate of the
                <strong>
                    Variance
                </strong>
            </a>
        </li>
        <li id="diagGGNItem">
            <a href="#diagGGNCode" id="diagGGNButton">
                the Gauss-Newton
                <strong>
                    Diagonal
                </strong>
            </a>
        </li>
        <li id="KFACItem">
            <a href="#KFACCode" id="KFACButton">
                the Gauss-Newton
                <strong>
                    KFAC
                </strong>
            </a>
        </li>
    </ul>
</center>

<!--- GRADIENT CODE --->
<div id="gradientCode">
<div class="language-python highlighter-rouge">
    <div class="highlight">
        <pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch

"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>


<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>


<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>

</code></pre>
    </div>
</div>
</div>

<!--- VARIANCE CODE --->
<div id="varianceCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the variance with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">Variance</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">Variance</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">variance</span><span class="p">)</span></span>
</code></pre></div></div></div>

<!--- SECOND MOMENT CODE (no tab in the list above and no handling in the script below: this panel is never shown) --->
<div id="secondMomentCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the second moment with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">SumGradSquared</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">SumGradSquared</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">sum_grad_squared</span><span class="p">)</span></span>
</code></pre></div></div></div>

<!--- DIAGGGN CODE --->
<div id="diagGGNCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the diagonal of the Gauss-Newton with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">DiagGGNExact</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">DiagGGNExact</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">diag_ggn_exact</span><span class="p">)</span></span>
</code></pre></div></div></div>

<!--- KFAC CODE --->
<div id="KFACCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and KFAC with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">KFAC</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">KFAC</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">kfac</span><span class="p">)</span></span>
</code></pre></div></div>
</div>

<script>
    /*
     * Tab switcher for the code examples.
     *
     * Every tab name maps to three element ids:
     *   "<name>Button"  the link that is clicked,
     *   "<name>Item"    the <li> that carries the "active" class,
     *   "<name>Code"    the panel holding the example.
     * Clicking a button marks its item as active and shows its panel; all
     * other items lose the "active" class and all other panels are hidden.
     * The initial state (which tab is active, which panels are hidden) comes
     * from the markup and is left untouched until the first click.
     */
    (function() {
        var tabNames = ["gradient", "variance", "diagGGN", "KFAC"];

        function initCodeTabs() {
            // Resolve the three elements of each tab once. A tab whose markup
            // is incomplete is skipped instead of breaking the remaining tabs.
            var tabs = tabNames.map(function(name) {
                return {
                    button: document.getElementById(name + "Button"),
                    item: document.getElementById(name + "Item"),
                    code: document.getElementById(name + "Code")
                };
            }).filter(function(tab) {
                return tab.button && tab.item && tab.code;
            });

            // Show the selected tab and hide every other one.
            function activate(selected) {
                tabs.forEach(function(tab) {
                    if (tab === selected) {
                        // classList leaves any other class on the item alone.
                        tab.item.classList.add("active");
                        tab.code.style.display = "block";
                    } else {
                        tab.item.classList.remove("active");
                        tab.code.style.display = "none";
                    }
                });
            }

            tabs.forEach(function(tab) {
                tab.button.addEventListener("click", function(event) {
                    // The buttons are links: keep the browser from following the href.
                    event.preventDefault();
                    activate(tab);
                });
            });
        }

        // Attach the handlers as soon as the DOM is parsed rather than after
        // every image and stylesheet has loaded, and without replacing a
        // window.onload handler that another script may have installed.
        if (document.readyState === "loading") {
            document.addEventListener("DOMContentLoaded", initCodeTabs);
        } else {
            initCodeTabs();
        }
    })();
</script>

<hr />

<p><strong>Install with</strong></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip install backpack-for-pytorch 
</code></pre></div></div>

<hr />

<p>If you use BackPACK in your research, please cite <span style="float:right"><a href="/assets/dangel2020backpack.bib">download bibtex</a></span></p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@inproceedings{dangel2020backpack,
    title = {BackPACK: Packing more into Backprop},
    author = {Felix Dangel and Frederik Kunstner and Philipp Hennig},
    booktitle = {International Conference on Learning Representations},
    year = {2020},
    url = {https://openreview.net/forum?id=BJlrF24twB}
}
</code></pre></div></div>


    </section>
    <footer>
        <p>
            <small>
                Hosted on GitHub Pages — Theme by
                <a href="https://github.com/orderedlist">
                    orderedlist
                </a>
            </small>
        </p>
    </footer>
</div>
<script src="/assets/js/scale.fix.js">
</script>
</body>
